package com.common.bean.tokenstore;

import cn.hutool.cache.CacheUtil;
import cn.hutool.cache.impl.TimedCache;
import cn.hutool.core.collection.CollUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.DefaultAuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.TokenStore;

import java.util.Collection;
import java.util.Collections;

/**
 * Process-local {@link TokenStore} that keeps all OAuth2 token state in one hutool {@link TimedCache}.
 *
 * <p>Each piece of state is a separate cache entry under a prefixed key:
 * <ul>
 *   <li>{@code access:<token>} holds the {@link OAuth2AccessToken};</li>
 *   <li>{@code auth:<token>} holds the {@link OAuth2Authentication} of that access token;</li>
 *   <li>{@code auth_to_access:<authKey>} holds the access token last stored for an authentication key;</li>
 *   <li>{@code uname_to_access:<clientId>:<user>} and {@code client_id_to_access:<clientId>} hold the
 *       access token most recently stored for that user / client;</li>
 *   <li>{@code refresh:}, {@code refresh_auth:}, {@code refresh_to_access:} and
 *       {@code access_to_refresh:} hold refresh-token data and the access/refresh links.</li>
 * </ul>
 *
 * <p>Limitations: the user and client indexes hold a single token each, so
 * {@link #findTokensByClientIdAndUserName} and {@link #findTokensByClientId} return at most one
 * token. The cache is a {@code static} field, so every instance of this class shares the same entries.
 */
@Slf4j
public class BaseLocalTokenStore implements TokenStore {

    private static final String ACCESS = "access:";
    private static final String AUTH_TO_ACCESS = "auth_to_access:";
    private static final String AUTH = "auth:";
    private static final String REFRESH_AUTH = "refresh_auth:";
    private static final String ACCESS_TO_REFRESH = "access_to_refresh:";
    private static final String REFRESH = "refresh:";
    private static final String REFRESH_TO_ACCESS = "refresh_to_access:";
    private static final String CLIENT_ID_TO_ACCESS = "client_id_to_access:";
    private static final String UNAME_TO_ACCESS = "uname_to_access:";

    /** Default cache TTL in milliseconds: 604800 seconds, i.e. 7 days. */
    private static final int DEFAULT_EXPIRATION_MILLIS = 604800 * 1000;

    /**
     * Shared token cache. Entries put without an explicit timeout expire after the 7-day default.
     *
     * <p>NOTE(review): no {@code schedulePrune(...)} is started here; confirm that entries which are
     * never read again are evicted (hutool otherwise prunes lazily).
     */
    protected static TimedCache<String, Object> TOKEN_CACHE = CacheUtil.newTimedCache(DEFAULT_EXPIRATION_MILLIS);

    private AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();

    /**
     * Replaces the generator used to derive the {@code auth_to_access:} key from an authentication.
     *
     * @param authenticationKeyGenerator the key generator to use
     */
    public void setAuthenticationKeyGenerator(AuthenticationKeyGenerator authenticationKeyGenerator) {
        this.authenticationKeyGenerator = authenticationKeyGenerator;
    }

    /**
     * Returns the access token last stored for the given authentication, or {@code null}.
     *
     * <p>If the authentication cached for that token is missing, or maps to a different
     * authentication key, the token is stored again under {@code authentication} so the entries
     * agree. Spring's {@code RedisTokenStore} does the same.
     *
     * @param authentication the authentication to look up
     * @return the cached access token, or {@code null} if none is indexed
     */
    @Override
    public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
        String key = authenticationKeyGenerator.extractKey(authentication);
        OAuth2AccessToken accessToken = (OAuth2AccessToken) TOKEN_CACHE.get(AUTH_TO_ACCESS + key);
        if (accessToken != null) {
            OAuth2Authentication storedAuthentication = readAuthentication(accessToken.getValue());
            // The auth: entry can expire or be removed independently of auth_to_access:. The default
            // key generator throws a NullPointerException on null, so treat that case as a mismatch.
            if (storedAuthentication == null
                    || !key.equals(authenticationKeyGenerator.extractKey(storedAuthentication))) {
                storeAccessToken(accessToken, authentication);
            }
        }
        return accessToken;
    }

    /**
     * Returns the authentication cached for the given access token, or {@code null}.
     */
    @Override
    public OAuth2Authentication readAuthentication(OAuth2AccessToken token) {
        return readAuthentication(token.getValue());
    }

    /**
     * Returns the authentication cached for the given access token value, or {@code null}.
     */
    @Override
    public OAuth2Authentication readAuthentication(String token) {
        return (OAuth2Authentication) TOKEN_CACHE.get(AUTH + token);
    }

    /**
     * Returns the authentication cached for the given refresh token, or {@code null}.
     */
    @Override
    public OAuth2Authentication readAuthenticationForRefreshToken(OAuth2RefreshToken token) {
        return readAuthenticationForRefreshToken(token.getValue());
    }

    /**
     * Returns the authentication cached for the given refresh token value, or {@code null}.
     */
    public OAuth2Authentication readAuthenticationForRefreshToken(String token) {
        return (OAuth2Authentication) TOKEN_CACHE.get(REFRESH_AUTH + token);
    }

    /**
     * Stores an access token together with its authentication and all lookup indexes.
     *
     * <p>All entries written here share one TTL. It is the token's remaining lifetime when the
     * token has an expiration, and the 7-day default otherwise.
     *
     * @param token          the access token to store
     * @param authentication the authentication the token was issued for
     */
    @Override
    public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
        long ttlMillis = accessTokenTtlMillis(token);
        String tokenValue = token.getValue();

        TOKEN_CACHE.put(ACCESS + tokenValue, token, ttlMillis);
        TOKEN_CACHE.put(AUTH + tokenValue, authentication, ttlMillis);
        TOKEN_CACHE.put(AUTH_TO_ACCESS + authenticationKeyGenerator.extractKey(authentication), token, ttlMillis);
        if (!authentication.isClientOnly()) {
            TOKEN_CACHE.put(UNAME_TO_ACCESS + getApprovalKey(authentication), token, ttlMillis);
        }

        TOKEN_CACHE.put(CLIENT_ID_TO_ACCESS + authentication.getOAuth2Request().getClientId(), token, ttlMillis);

        OAuth2RefreshToken refreshToken = token.getRefreshToken();
        if (refreshToken != null && refreshToken.getValue() != null) {
            TOKEN_CACHE.put(REFRESH_TO_ACCESS + refreshToken.getValue(), tokenValue, ttlMillis);
            TOKEN_CACHE.put(ACCESS_TO_REFRESH + tokenValue, refreshToken.getValue(), ttlMillis);
        }
    }

    /**
     * Computes the cache TTL, in milliseconds, for an access token's entries.
     *
     * @param token the access token being stored
     * @return a strictly positive TTL in milliseconds
     */
    private static long accessTokenTtlMillis(OAuth2AccessToken token) {
        if (token.getExpiration() == null) {
            return DEFAULT_EXPIRATION_MILLIS;
        }
        // Long arithmetic: an int "expiresIn * 1000" overflows for lifetimes beyond ~24.8 days.
        long remainingMillis = token.getExpiration().getTime() - System.currentTimeMillis();
        // hutool treats a non-positive timeout as "never expire". An already-expired token
        // therefore gets a minimal positive TTL instead of being cached forever.
        return Math.max(1L, remainingMillis);
    }

    /**
     * Builds the {@code uname_to_access:} key suffix for an authentication.
     * Client-only authentications use an empty user name.
     */
    private String getApprovalKey(OAuth2Authentication authentication) {
        String userName = authentication.getUserAuthentication() == null ? "" : authentication.getUserAuthentication()
                .getName();
        return getApprovalKey(authentication.getOAuth2Request().getClientId(), userName);
    }

    /**
     * Builds the {@code uname_to_access:} key suffix: {@code clientId}, plus {@code ":" + userName}
     * when {@code userName} is non-null.
     */
    private String getApprovalKey(String clientId, String userName) {
        return clientId + (userName == null ? "" : ":" + userName);
    }

    /**
     * Removes the given access token and its index entries.
     */
    @Override
    public void removeAccessToken(OAuth2AccessToken accessToken) {
        deleteAccessToken(accessToken.getValue());
    }

    /**
     * Returns the cached access token for the given value, or {@code null}.
     */
    @Override
    public OAuth2AccessToken readAccessToken(String tokenValue) {
        return (OAuth2AccessToken) TOKEN_CACHE.get(ACCESS + tokenValue);
    }

    /**
     * Removes an access token, its authentication, its refresh link and every index entry that
     * still points at it.
     *
     * <p>The auth-key, user and client indexes hold one token each. They are cleared only when
     * they still reference {@code tokenValue}, so a newer token stored under the same key stays
     * reachable.
     *
     * @param tokenValue the access token value to remove
     */
    public void deleteAccessToken(String tokenValue) {
        OAuth2Authentication authentication = (OAuth2Authentication) TOKEN_CACHE.get(AUTH + tokenValue);

        TOKEN_CACHE.remove(AUTH + tokenValue);
        TOKEN_CACHE.remove(ACCESS + tokenValue);
        TOKEN_CACHE.remove(ACCESS_TO_REFRESH + tokenValue);

        if (authentication != null) {
            removeIfIndexedTo(AUTH_TO_ACCESS + authenticationKeyGenerator.extractKey(authentication), tokenValue);
            // Mirror storeAccessToken: the user index is only written for non-client-only auths.
            if (!authentication.isClientOnly()) {
                removeIfIndexedTo(UNAME_TO_ACCESS + getApprovalKey(authentication), tokenValue);
            }
            removeIfIndexedTo(CLIENT_ID_TO_ACCESS + authentication.getOAuth2Request().getClientId(), tokenValue);
        }
    }

    /**
     * Removes the index entry at {@code key} only if it currently holds the access token whose
     * value is {@code tokenValue}.
     */
    private static void removeIfIndexedTo(String key, String tokenValue) {
        Object indexed = TOKEN_CACHE.get(key);
        if (indexed instanceof OAuth2AccessToken
                && tokenValue.equals(((OAuth2AccessToken) indexed).getValue())) {
            TOKEN_CACHE.remove(key);
        }
    }

    /**
     * Stores a refresh token and its authentication.
     *
     * <p>NOTE(review): both entries use the cache's default 7-day TTL, not the refresh token's own
     * expiration (if it has one). Confirm this matches the configured refresh-token validity.
     */
    @Override
    public void storeRefreshToken(OAuth2RefreshToken refreshToken, OAuth2Authentication authentication) {
        TOKEN_CACHE.put(REFRESH + refreshToken.getValue(), refreshToken);
        TOKEN_CACHE.put(REFRESH_AUTH + refreshToken.getValue(), authentication);
    }

    /**
     * Returns the cached refresh token for the given value, or {@code null}.
     */
    @Override
    public OAuth2RefreshToken readRefreshToken(String tokenValue) {
        return (OAuth2RefreshToken) TOKEN_CACHE.get(REFRESH + tokenValue);
    }

    /**
     * Removes the given refresh token and its entries.
     */
    @Override
    public void removeRefreshToken(OAuth2RefreshToken refreshToken) {
        deleteRefreshToken(refreshToken.getValue());
    }

    /**
     * Removes the refresh token, its authentication and its refresh-to-access link.
     *
     * @param tokenValue the refresh token value to remove
     */
    public void deleteRefreshToken(String tokenValue) {
        TOKEN_CACHE.remove(REFRESH + tokenValue);
        TOKEN_CACHE.remove(REFRESH_AUTH + tokenValue);
        TOKEN_CACHE.remove(REFRESH_TO_ACCESS + tokenValue);
    }

    /**
     * Revokes the access token linked to the given refresh token.
     */
    @Override
    public void removeAccessTokenUsingRefreshToken(OAuth2RefreshToken refreshToken) {
        deleteAccessTokenUsingRefreshToken(refreshToken.getValue());
    }

    /**
     * Removes the refresh-to-access link and fully revokes the linked access token.
     */
    private void deleteAccessTokenUsingRefreshToken(String refreshToken) {
        String refreshToAccessKey = REFRESH_TO_ACCESS + refreshToken;
        String accessTokenValue = (String) TOKEN_CACHE.get(refreshToAccessKey);
        TOKEN_CACHE.remove(refreshToAccessKey);

        // Remove the auth entry and indexes too, not only access:<token>. Otherwise they keep
        // pointing at a token that no longer exists until their TTL runs out.
        if (accessTokenValue != null) {
            deleteAccessToken(accessTokenValue);
        }
    }

    /**
     * Returns the access token most recently stored for the client/user pair.
     * The index holds one token, so the result has at most one element.
     */
    @Override
    public Collection<OAuth2AccessToken> findTokensByClientIdAndUserName(String clientId, String userName) {
        return indexedTokenAsCollection(UNAME_TO_ACCESS + getApprovalKey(clientId, userName));
    }

    /**
     * Returns the access token most recently stored for the client.
     * The index holds one token, so the result has at most one element.
     */
    @Override
    public Collection<OAuth2AccessToken> findTokensByClientId(String clientId) {
        return indexedTokenAsCollection(CLIENT_ID_TO_ACCESS + clientId);
    }

    /**
     * Wraps the access token at the index {@code key} in an unmodifiable collection; returns an
     * empty set when nothing is indexed.
     */
    private static Collection<OAuth2AccessToken> indexedTokenAsCollection(String key) {
        OAuth2AccessToken result = (OAuth2AccessToken) TOKEN_CACHE.get(key);
        if (result == null) {
            return Collections.emptySet();
        }
        return Collections.unmodifiableCollection(CollUtil.newArrayList(result));
    }
}